# -*- coding: utf-8 -*-
import os
from apps.dlt.models import model
from dataset.dlt_dataset import LottoDataSet
from config.settings import DLTSettings

# Build the dataset object that exposes the training arrays.
# NOTE(review): train_data_rate=1 presumably assigns every sample to the
# training split -- confirm against LottoDataSet.
lotto_dataset = LottoDataSet(train_data_rate=1)
# Ensure the directory that will hold the saved weights exists.
# makedirs(exist_ok=True) also creates missing parent directories and does
# not fail if the directory is created between a check and the creation,
# unlike an os.path.exists() test followed by os.mkdir().
os.makedirs(DLTSettings.CHECKPOINTS_PATH, exist_ok=True)
# Train on the dataset's training arrays with the configured batch size and
# number of epochs.
model.fit(
    lotto_dataset.train_np_x,
    lotto_dataset.train_np_y,
    batch_size=DLTSettings.BATCH_SIZE,
    epochs=DLTSettings.EPOCHS,
)
# Save the model weights under the checkpoint directory.
model.save_weights('{}/model_checkpoint_x'.format(DLTSettings.CHECKPOINTS_PATH))
